{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc80121e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prerequisite: this notebook imports `transformers` (GPT-2 tokenizer and model).\n",
    "# If the package is missing, uncomment the line below and run it once;\n",
    "# `%pip` installs into the environment of the running kernel.\n",
    "# %pip install transformers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "678329aa-5aa5-471a-b23a-be009fe09695",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from transformers import GPT2Tokenizer, GPT2Model\n",
    "\n",
    "from tensor_visualizer import TensorVisualizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71571f43-b453-4173-a016-26a833b98c48",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample sentence whose attention patterns are visualized below.\n",
    "text = \"Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.\"\n",
    "\n",
    "# Load the pretrained 'gpt2' tokenizer and model.\n",
    "tokenizer = GPT2Tokenizer.from_pretrained('gpt2')\n",
    "model = GPT2Model.from_pretrained('gpt2')\n",
    "\n",
    "# Tokenize the single sentence into PyTorch tensors ('pt').\n",
    "encoded_input = tokenizer(text, return_tensors='pt')\n",
    "\n",
    "# Inference only: disable gradient tracking for the forward pass, since only\n",
    "# the attention weights are read afterwards and no backward pass is run.\n",
    "with torch.no_grad():\n",
    "    output = model(**encoded_input, output_attentions=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "615edef3-032a-4473-b5b5-ac779eb33d40",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Token strings for the first (and only) encoded sequence; used as axis labels.\n",
    "first_sequence_ids = encoded_input.input_ids[0]\n",
    "tokens = tokenizer.convert_ids_to_tokens(first_sequence_ids)\n",
    "\n",
    "# One attention tensor per entry of `output.attentions`; index 0 selects the\n",
    "# first item along the leading axis (presumably the batch axis - confirm).\n",
    "per_layer_attention = [\n",
    "    layer_attention[0].detach().cpu().numpy()\n",
    "    for layer_attention in output.attentions\n",
    "]\n",
    "all_attentions = np.array(per_layer_attention)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62b60948-6774-4b27-ba21-0e91979b0570",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the attention array with named axes; the two token axes are\n",
    "# labelled with the token strings, the first two axes get no labels.\n",
    "TensorVisualizer(\n",
    "    all_attentions,\n",
    "    names=[\"layer\", \"head\", \"query\", \"key\"],\n",
    "    labels=[None, None, tokens, tokens],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb693716-d573-4a16-b94a-31661ff5b4da",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same data, names and labels, with an explicit axis order passed via `permute`.\n",
    "# NOTE(review): presumably [2, 0, 1, 3] puts the query axis first - confirm\n",
    "# against the TensorVisualizer documentation.\n",
    "TensorVisualizer(\n",
    "    all_attentions,\n",
    "    names=[\"layer\", \"head\", \"query\", \"key\"],\n",
    "    labels=[None, None, tokens, tokens],\n",
    "    permute=[2, 0, 1, 3],\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
